#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "boost/algorithm/string.hpp"
#include "DxbcContainer.h"

using namespace boolinq;
using namespace std;
using namespace SlimShader;

// Reads the entire contents of the file `name` as raw bytes (binary mode, so no
// line-ending translation is applied).
//
// Throws std::runtime_error if the file cannot be opened. Without this check a
// missing or mistyped test input would be returned as an empty buffer and only
// fail later, obscurely, inside the bytecode parser.
std::vector<char> ReadFileBytes(const std::string name)
{
	std::ifstream file(name, std::ios::binary);
	if (!file)
		throw std::runtime_error("Could not open file: " + name);

	// The extra parentheses around the first argument avoid the most vexing parse.
	return std::vector<char>(
		(std::istreambuf_iterator<char>(file)),
		std::istreambuf_iterator<char>());
}

// Loads an fxc .asm listing and returns its body, normalised for comparison
// against our own disassembly:
//   * the leading compiler header is dropped,
//   * every remaining line is trimmed of leading/trailing whitespace,
//   * lines are joined with '\n' (no trailing newline).
//
// The first 5 or 6 lines of an fxc listing look something like:
//
//   //
//   // Generated by Microsoft (R) HLSL Shader Compiler 9.29.952.3111
//   //
//   //
//   //   fxc /T vs_4_0 /Fo multiple_const_buffers.o /Fc multiple_const_buffers.asm
//   //    multiple_const_buffers
//
// We can't accurately recreate the fxc command line, so those lines can't be
// string-compared. Starting at line index 5, everything before the first line
// that is exactly "//" is skipped; the output begins with that "//" line.
//
// Throws std::runtime_error if the file cannot be opened, or if no "//" line
// exists at or after index 5 (previously an out-of-bounds read).
std::string GetAsmText(std::string asmFile)
{
	std::ifstream file(asmFile);
	if (!file)
		throw std::runtime_error("Could not open file: " + asmFile);

	std::vector<std::string> asmFileLines;
	std::string asmFileLine;
	while (std::getline(file, asmFileLine))
	{
		// CRLF files read on platforms that don't translate line endings keep a
		// trailing '\r', which would stop the "//" terminator below from matching.
		if (!asmFileLine.empty() && asmFileLine.back() == '\r')
			asmFileLine.pop_back();
		asmFileLines.push_back(asmFileLine);
	}

	std::size_t skip = 5;
	while (skip < asmFileLines.size() && asmFileLines[skip] != "//")
		++skip;
	if (skip >= asmFileLines.size())
		throw std::runtime_error("Unexpected header format in " + asmFile
			+ ": no \"//\" line found after the fxc command line.");

	// Cast to unsigned char: std::isspace is undefined for negative char values.
	auto isSpace = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

	std::string result;
	for (std::size_t i = skip; i < asmFileLines.size(); ++i)
	{
		const std::string& line = asmFileLines[i];
		const auto first = std::find_if_not(line.begin(), line.end(), isSpace);
		const auto last = std::find_if_not(line.rbegin(), line.rend(), isSpace).base();
		if (i != skip)
			result += '\n';
		if (first < last) // an all-whitespace line trims to nothing
			result.append(first, last);
	}
	return result;
}

void CompareAssemblyOutput(string file, DxbcContainer& container)
{
	// Read .asm file.
	string asmFile = file + ".asm";
	auto asmFileText = GetAsmText(asmFile);

	// Get ASM output from decompiled bytecode.
	std::stringstream stream;
	stream << container;

	// Ignore first 5 lines - they contain the compiler-specific headers.
	vector<string> decompiledAsmLines;
	string decompiledAsmLine;
	while (getline(stream, decompiledAsmLine))
		decompiledAsmLines.push_back(decompiledAsmLine);
	string decompiledAsmText = boost::algorithm::join(from(decompiledAsmLines)
		.skip(5)
		.select([](string line) { boost::algorithm::trim(line); return line; })
		.toVector(), "\n");

	// Compare strings.
	if (decompiledAsmText.compare(asmFileText) != 0)
	{
		int differIndex = 0;
		for (size_t i = 0; i < decompiledAsmText.size(); i++)
		{
			if (decompiledAsmText.at(i) != asmFileText.at(i))
			{
				differIndex = i;
				break;
			}
		}
		const int numContextCharacters = 40;
		string message = "Expected string length " + to_string(asmFileText.size())
			+ " but was " + to_string(decompiledAsmText.size())
			+ ". Strings differ at index " + to_string(differIndex) + ".\n"
			+ "Expected: \"" + asmFileText.substr(differIndex - numContextCharacters, numContextCharacters * 2) + "\"\n"
			+ "But was:  \"" + decompiledAsmText.substr(differIndex - numContextCharacters, numContextCharacters * 2) + "\"";
		throw runtime_error("Decompiled version does not match fxc.exe output.\n\n" + message);
	}
}

// Runs one regression case. `file` is a path relative to ../shaders/ without an
// extension: <file>.o holds the compiled bytecode that gets parsed, and
// <file>.asm holds the fxc listing the disassembly is compared against.
void TestFile(string file)
{
	const std::string shaderPath = "../shaders/" + file;

	auto bytecode = ReadFileBytes(shaderPath + ".o");
	auto parsed = DxbcContainer::Parse(bytecode);

	CompareAssemblyOutput(shaderPath, parsed);
}

int main(int argc, char* argv[])
{
	try
	{
		TestFile("FxDis/test_PS");
		TestFile("FxDis/test_VS");
		TestFile("HlslCrossCompiler/ds5/basic");
		TestFile("HlslCrossCompiler/gs4/CubeMap_Inst");
		TestFile("HlslCrossCompiler/hs5/basic");
		TestFile("HlslCrossCompiler/ps4/fxaa");
		TestFile("HlslCrossCompiler/ps4/primID");
		TestFile("HlslCrossCompiler/ps5/conservative_depth_ge");
		TestFile("HlslCrossCompiler/ps5/interface_arrays");
		TestFile("HlslCrossCompiler/ps5/interfaces");
		TestFile("HlslCrossCompiler/ps5/sample");
		TestFile("HlslCrossCompiler/vs4/mov");
		TestFile("HlslCrossCompiler/vs4/multiple_const_buffers");
		TestFile("HlslCrossCompiler/vs4/switch");
		TestFile("HlslCrossCompiler/vs5/any");
		TestFile("HlslCrossCompiler/vs5/const_temp");
		TestFile("HlslCrossCompiler/vs5/mad_imm");
		TestFile("HlslCrossCompiler/vs5/mov");
		TestFile("HlslCrossCompiler/vs5/sincos");
		TestFile("Sdk/Direct3D11/AdaptiveTessellationCS40/TessellatorCS40_EdgeFactorCS");
		TestFile("Sdk/Direct3D11/AdaptiveTessellationCS40/TessellatorCS40_NumVerticesIndicesCS");
		TestFile("Sdk/Direct3D11/AdaptiveTessellationCS40/TessellatorCS40_ScatterIDCS");
		TestFile("Sdk/Direct3D11/AdaptiveTessellationCS40/TessellatorCS40_TessellateIndicesCS");
		TestFile("Sdk/Direct3D11/AdaptiveTessellationCS40/TessellatorCS40_TessellateVerticesCS");
		TestFile("Sdk/Direct3D11/BasicCompute11/BasicCompute11");
		//TestFile("s/Sdk/Direct3D11/BasicHLSL11/BasicHLSLPS"); // Can't parse SDBG chunk type yet.
		//TestFile("s/Sdk/Direct3D11/BasicHLSL11/BasicHLSLVS");
		TestFile("Sdk/Direct3D11/BC6HBC7EncoderDecoder11/BC6HDecode");
		TestFile("Sdk/Direct3D11/BC6HBC7EncoderDecoder11/BC7Decode");
		TestFile("Sdk/Direct3D11/BC6HBC7EncoderDecoder11/BC7Encode");
		TestFile("Sdk/Direct3D11/DynamicShaderLinkage11/DynamicShaderLinkage11_PS");
		TestFile("Sdk/Direct3D11/NBodyGravityCS11/NBodyGravityCS11");
		TestFile("Sdk/Direct3D11/NBodyGravityCS11/ParticleDrawGS");
		TestFile("Sdk/Direct3D11/NBodyGravityCS11/ParticleDrawPS");
		TestFile("Sdk/Direct3D11/NBodyGravityCS11/ParticleDrawVS");
		TestFile("Sdk/Direct3D11/SimpleBezier11/SimpleBezier11DS");
		TestFile("Sdk/Direct3D11/SimpleBezier11/SimpleBezier11HS");
		TestFile("Sdk/Direct3D11/SimpleBezier11/SimpleBezier11PS");
		TestFile("Sdk/Direct3D11/SimpleBezier11/SimpleBezier11VS");
	}
	catch (const runtime_error& e)
	{
		cout << endl << endl << e.what() << endl;
	}
	return 0;
}